feat: Hybrid unified/recurrent cache #13276
Conversation
Awesome to see this progress!
src/llama-kv-cache.cpp (outdated):
```cpp
// TODO: Will it cause problems if some caches are able to remove the seq
// but others aren't?
```
Yes, it will cause problems if this breaks the coherency between caches (e.g. part of a sequence is removed in one cache but not the other).

This is what I was referring to in #12799 (comment) when I wrote:

The hardest part will be handling errors and properly keeping coherency between the different types of caches (because they don't necessarily roll back states in the same way).

I think the `seq_rm` API might fundamentally be too specific to the self-attention KV cache. Recurrent models can't roll back their state, because intermediate states are not kept; keeping them for all tokens would take too much space. (When `seq_rm` returns false, it means the states have to be re-calculated from scratch for the affected sequence — at least that was the intention in #5328.)

Ideally, if there were some API to create snapshots and roll back to them, the implementation would be simpler for recurrent models (and for hybrid models by extension). Technically, sequences (with `seq_id`) already kind of do this (and are copy-on-write), but snapshots within sequences might be more convenient to manage in user code, since managing which state is the latest per sequence could be done transparently.

But that would also mean having to manage the lifetime of explicit state snapshots (in `examples/server/server.cpp`, among others) instead of directly dealing with ranges of token positions (and might make things like largest-common-prefix context caching harder to handle). I've previously shared some ideas about state snapshots/checkpoints in #7531 (comment) (although the first half of that comment is about session restore, as in `state_read`).
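To make the idea concrete, a snapshot API along these lines might look roughly like the sketch below. Every name here is hypothetical (nothing like this exists in llama.h today); it only illustrates the lifetime-management burden mentioned above:

```cpp
#include "llama.h" // for llama_context, llama_seq_id

// Hypothetical opaque handle to a saved per-sequence state.
struct llama_state_snapshot;

// Capture the current state of a sequence. For a recurrent cache this is
// comparatively cheap: one copy of the per-layer states, not one per token.
llama_state_snapshot * llama_snapshot_create(llama_context * ctx, llama_seq_id seq_id);

// Roll the sequence back to the snapshot instead of recomputing from scratch.
bool llama_snapshot_restore(llama_context * ctx, llama_state_snapshot * snap);

// Snapshots are explicit resources, so user code (e.g. examples/server)
// would be responsible for freeing them.
void llama_snapshot_free(llama_state_snapshot * snap);
```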
Ok, interesting. I'm definitely still learning on-the-fly here, but based on this description and the logic here in `server.cpp`, it seems like the options are either to leak implementation details of the child caches or to introduce a new member API, `can_seq_rm`, that is `const` but performs the same validation. I think I'll give the latter a shot and see how far I can get.
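For reference, the shape I have in mind is roughly this (a hypothetical sketch — `can_seq_rm` is the new member under discussion, not an existing API):

```cpp
// In the common cache interface: the same validation seq_rm performs,
// but const and side-effect free.
virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0;

// The hybrid cache only answers yes if every child agrees, so a removal
// is never attempted that could succeed in one child and fail in another.
bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
    for (const auto & cache : m_children) {
        if (!cache->can_seq_rm(seq_id, p0, p1)) {
            return false;
        }
    }
    return true;
}
```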
Ok, I've pushed an attempt at doing this safely. One thing I noticed is that these mutating methods don't seem to have any sort of locking mechanism, so the way I have it implemented could certainly be prone to thread-safety problems if concurrent threads tried to call `seq_rm`. I don't think this is any different from the current cache implementations, since those would also be sensitive to the same races where the validated condition changes after validation but before the members get mutated, but I wanted to double-check whether this kind of thread safety is guarded against elsewhere (or just assumed to be handled in the client layer).
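Concretely, the validate-then-mutate pattern looks something like the sketch below (not the exact pushed code), which is also where the unguarded window sits:

```cpp
bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // Phase 1: validate against all children so no child is ever mutated
    // unless the removal can succeed everywhere.
    if (!can_seq_rm(seq_id, p0, p1)) {
        return false;
    }
    // <-- race window: with no locking, another thread could mutate a child
    //     here and invalidate the check above; callers must serialize access.
    // Phase 2: perform the removal; each child is now expected to succeed.
    for (auto & cache : m_children) {
        const bool ok = cache->seq_rm(seq_id, p0, p1);
        GGML_ASSERT(ok && "child seq_rm failed after successful validation");
    }
    return true;
}
```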
src/llama-kv-cache.cpp (outdated):
```cpp
// If any of the caches are recurrent, require simple split
return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
```
Simple split should not be used with recurrent models; they expect equal split. See #7531 (comment), which illustrates the splits.
Suggested change:
```diff
- // If any of the caches are recurrent, require simple split
- return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
+ // If any of the caches are recurrent, require non-simple split
+ return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
```
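Roughly, the distinction the suggestion points at — a sketch of one reading of the #7531 illustration, with two hypothetical sequences and n_ubatch = 4:

```cpp
// batch: seq 0 = {a0 a1 a2 a3}, seq 1 = {b0 b1}, n_ubatch = 4
//
// split_simple: tokens taken in batch order, sequence boundaries ignored,
// so a ubatch may contain unequal token counts per sequence:
//   ubatch 0: a0 a1 a2 a3
//   ubatch 1: b0 b1
//
// split_equal: every ubatch takes the same number of tokens from each
// sequence it contains, keeping per-sequence recurrent states in lockstep:
//   ubatch 0: a0 a1 b0 b1   (2 tokens from each sequence)
//   ubatch 1: a2 a3
```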
Thanks for the comment pointer, this is super helpful for understanding what the consequences of these actually are!
src/llama-kv-cache.cpp (outdated):
```cpp
if (m_has_recurrent) {
    return sbatch.split_simple(n_ubatch);
}
```
This will not work; recurrent models expect `split_equal` to be used.
Ok, I'm following now. I had them backwards in my head.
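So the branch presumably needs to flip to the equal split, along the lines of:

```cpp
if (m_has_recurrent) {
    // recurrent children require equal split so each ubatch advances every
    // sequence's recurrent state by the same number of tokens
    return sbatch.split_equal(n_ubatch);
}
```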
```cpp
// TODO: Is this correct?
// If any children can shift, return true
for (const auto & cache : m_children) {
    if (cache->get_can_shift()) {
        return true;
    }
}
```
Maybe this should be: if all children can shift, then return true. But as you've noticed elsewhere, `can_shift` should technically always be true for all currently-implemented cache types, so I don't know if that part of the API will stay anyway.
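i.e., a minimal sketch of the all-children variant:

```cpp
#include <algorithm>

bool llama_kv_cache_hybrid::get_can_shift() const {
    // a shift is only safe if every child cache supports it; otherwise the
    // children would fall out of sync with each other
    return std::all_of(m_children.begin(), m_children.end(),
            [](const auto & cache) { return cache->get_can_shift(); });
}
```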
@compilade I now have a proof-point (#13550) that this works to some extent, though I haven't tested it robustly for edge cases. There are a few additional changes I needed to make on that branch that should maybe come over to this branch, but it gets a little hairy because they interact with adding the actual model architectures. Some possible paths forward:

Thoughts / Preferences?
I thought about it further and decided that the cleanest separation between this and the Granite 4 branch is to pull over the key parts of…
src/llama-model.h (outdated):
```diff
@@ -402,7 +402,10 @@ struct llama_model {

     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
-    llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
+    llama_memory_i * create_memory(
```
I just realized that this is entirely unnecessary, since `llama_model` already has `hparams` as a member variable 🤦
@compilade @ggerganov I'm going to keep working on some unit tests for the different cache types and try to sort through the CI failure. In the meantime, would it be possible to get an initial review to make sure I'm on the right track?
It looks like the failing tests were all on Windows and therefore fall into this note about not running tests using internal APIs on Windows due to the lack of…
I think the hybrid implementation can become similar to the new iSWA implementation, mainly for consistency. I would like to now simplify the…

Will open a draft PR and we can discuss details there. Adding tests would be really useful; we can merge these from a separate PR without too much delay.
@ggerganov Thank you for pointing that out! I'll work to rebase this change on the latest iSWA changes. I'll also try to pull out the unit tests into their own PR.
@ggerganov Ported PR with minimal unit tests: #13669. I didn't try to expand them to incorporate the iSWA cache yet, but I will look to do that as I go through the changes there and refactor this hybrid cache branch.
These only test the basics so far, but should allow for more expansive tests to come. Branch: MemoryTests Signed-off-by: Gabe Goodhart <[email protected]>
These tests use private headers, so they won't build on Windows. Branch: MemoryTests Signed-off-by: Gabe Goodhart <[email protected]>
This will be key for the hybrid cache which needs to be able to validate that all children can perform seq_rm cleanly before attempting to remove the seq from any single child to avoid ending up in a corrupted state. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
This will be needed by other cache types as well, so centralizing the definition will make it more reusable. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Condensed from initial version https://github.com/gabe-l-hart/llama.cpp/tree/ec08571. The only difference is the removal of m_layer_cache_map, which was unused and unnecessary now that child caches are instantiated with their own filters. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch, with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybrid. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
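The delegation this commit describes is small; a sketch (the arch cases shown are examples, not the full list):

```cpp
// llama-arch: arch-level check, usable before any llama_model exists
bool llm_arch_is_recurrent(const llm_arch & arch) {
    switch (arch) {
        case LLM_ARCH_MAMBA:
        case LLM_ARCH_RWKV6:
            return true;
        default:
            return false;
    }
}

// llama-model: the public check just delegates to the arch-level one
bool llama_model_is_recurrent(const llama_model * model) {
    return llm_arch_is_recurrent(model->arch);
}
```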
Branch: GraniteFour
…s in hparams Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
…l is recurrent Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
This includes a slight architectural change where create_memory now only uses model architectures in the switch statement if their required cache type is not handled by llm_arch_is_[recurrent|hybrid]. Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
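In other words, the `create_memory` dispatch ends up shaped roughly as below (constructor arguments elided; a sketch of the described ordering rather than the exact code):

```cpp
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
    // generic cache-type checks come first ...
    if (llm_arch_is_recurrent(arch)) {
        return new llama_kv_cache_recurrent(/* ... */);
    }
    if (llm_arch_is_hybrid(arch)) {
        return new llama_kv_cache_hybrid(/* ... */);
    }
    // ... so the switch only needs architectures whose cache type is not
    // covered by the predicates above (e.g. iSWA variants)
    switch (arch) {
        default:
            return new llama_kv_cache_unified(/* ... */);
    }
}
```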
The implementation of the hybrid cache intentionally does not specify the types of the child caches, so there was a naming mismatch with these predicate functions that used "hybrid" to imply "hybrid recurrent." Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
@ggerganov This branch is now rebuilt based on the…
```diff
@@ -1898,6 +1916,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);

+    for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
```
This is buggy because `push_back` is used below, so later, when we index directly into the given per-layer tensor vectors, the final layers will be out-of-bounds reads.
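To make the failure mode concrete, and sketch one possible fix (a hypothetical 4-layer model where the filter keeps only layers 0 and 2; this ignores the buffer-type plumbing of the real constructor):

```cpp
// With push_back, the vectors end up densely packed:
//   k_l[0] -> layer 0, k_l[1] -> layer 2    (size() == 2)
// but later lookups index by layer id, so k_l[2] is an out-of-bounds read.
//
// One fix: size the vectors to n_layer up front and only populate the
// layers the filter keeps, so k_l[il] stays valid for every kept layer.
k_l.resize(n_layer, nullptr);
v_l.resize(n_layer, nullptr);
for (int i = 0; i < n_layer; i++) {
    if (filter && !filter(i)) {
        continue;
    }
    k_l[i] = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
    v_l[i] = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
}
```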
Branch: HybridCache Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour Signed-off-by: Gabe Goodhart <[email protected]>
I've updated the Granite 4 / Bamba branch (#13550) to test this out and found a couple of bugs; they should now be fixed and the branch functional.
Description

This implementation covers both the `llama_memory_i` and `llama_kv_cache` interfaces, but they could very well not be correct.

Discussion

I'm putting this up for discussion even though it doesn't have much value as a standalone change. My ultimate goal is support for the just-released Granite 4, which is a combination of `mamba2` and `granitemoeshared` layers. I opened #13275 to track the full scope of model architecture changes.